{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "name": "Gated Linear Units and Variants",
   "provenance": [],
   "collapsed_sections": [],
   "toc_visible": true
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AYV_dMVDxyc2"
   },
   "source": [
    "[![GitHub](https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social)](https://github.com/labmlai/annotated_deep_learning_paper_implementations)\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/glu_variants/simple.ipynb)\n",
    "\n",
    "## Gated Linear Units and Variants\n",
    "\n",
    "This notebook trains a simple [transformer](https://nn.labml.ai/transformers/) model for autoregressive prediction.\n",
    "It tries different variants of the [position-wise feedforward network](https://nn.labml.ai/transformers/feed_forward.html).\n",
    "\n",
    "The annotated trainer code is at [`simple.py`](https://nn.labml.ai/transformers/glu_variants/simple.html)."
   ]
  },
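  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what a gated feedforward variant looks like, the sketch below computes `W2(f(W x) * V x)`, where the choice of activation `f` gives the variant: sigmoid for GLU, ReLU for ReGLU, GELU for GEGLU, and Swish (SiLU) for SwiGLU.\n",
    "The class name `GatedFeedForward` and the layer sizes are assumptions made for this example; it is not the `labml_nn` implementation linked above.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class GatedFeedForward(nn.Module):\n",
    "    # Hypothetical minimal gated position-wise feedforward layer (illustration only)\n",
    "    def __init__(self, d_model: int, d_ff: int, activation: nn.Module):\n",
    "        super().__init__()\n",
    "        self.gate = nn.Linear(d_model, d_ff, bias=False)   # W: input to the activation\n",
    "        self.value = nn.Linear(d_model, d_ff, bias=False)  # V: linear branch\n",
    "        self.out = nn.Linear(d_ff, d_model, bias=False)    # W2: project back to d_model\n",
    "        self.activation = activation\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        # Element-wise product of the activated branch and the linear branch\n",
    "        return self.out(self.activation(self.gate(x)) * self.value(x))\n",
    "\n",
    "\n",
    "# Assumed example shape: [seq_len, batch_size, d_model]\n",
    "x = torch.randn(10, 4, 64)\n",
    "variants = {'GLU': nn.Sigmoid(), 'ReGLU': nn.ReLU(), 'GEGLU': nn.GELU(), 'SwiGLU': nn.SiLU()}\n",
    "for name, activation in variants.items():\n",
    "    y = GatedFeedForward(d_model=64, d_ff=256, activation=activation)(x)\n",
    "    # The layer is applied position-wise, so the output keeps the input shape\n",
    "    assert y.shape == x.shape\n",
    "    print(name, tuple(y.shape))\n",
    "```\n",
    "\n",
    "Only the activation differs between the variants, which is why a single module parameterized by `activation` is enough for this sketch."
   ]
  },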
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AahG_i2y5tY9"
   },
   "source": [
    "Install the `labml-nn` package"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "ZCzmCrAIVg0L",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "outputId": "2de76edb-9911-496d-9f8c-281dad6f5680"
   },
   "source": [
    "!pip install labml-nn"
   ],
   "execution_count": 1,
   "outputs": [
    {
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: labml-nn in /usr/local/lib/python3.6/dist-packages (0.4.82)\n",
      "Requirement already satisfied: labml>=0.4.97 in /usr/local/lib/python3.6/dist-packages (from labml-nn) (0.4.97)\n",
      "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.7.0+cu101)\n",
      "Requirement already satisfied: einops in /usr/local/lib/python3.6/dist-packages (from labml-nn) (0.3.0)\n",
      "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.19.5)\n",
      "Requirement already satisfied: labml-helpers>=0.4.72 in /usr/local/lib/python3.6/dist-packages (from labml-nn) (0.4.73)\n",
      "Requirement already satisfied: gitpython in /usr/local/lib/python3.6/dist-packages (from labml>=0.4.97->labml-nn) (3.1.12)\n",
      "Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from labml>=0.4.97->labml-nn) (3.13)\n",
      "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.16.0)\n",
      "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (3.7.4.3)\n",
      "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.8)\n",
      "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.6/dist-packages (from gitpython->labml>=0.4.97->labml-nn) (4.0.5)\n",
      "Requirement already satisfied: smmap<4,>=3.0.1 in /usr/local/lib/python3.6/dist-packages (from gitdb<5,>=4.0.1->gitpython->labml>=0.4.97->labml-nn) (3.0.5)\n"
     ],
     "name": "stdout"
    }
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SE2VUQ6L5zxI"
   },
   "source": [
    "Imports"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "0hJXx_g0wS2C"
   },
   "source": [
    "import dataclasses\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from labml import experiment\n",
    "from labml_nn.transformers.glu_variants.simple import Configs, Trainer"
   ],
   "execution_count": 2,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Lpggo0wM6qb-"
   },
   "source": [
    "Create an experiment"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "bFcr9k-l4cAg"
   },
   "source": [
    "experiment.create(name=\"glu_variants\")"
   ],
   "execution_count": 3,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-OnHLi626tJt"
   },
   "source": [
    "Initialize configurations"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "Piz0c5f44hRo"
   },
   "source": [
    "conf = Configs()"
   ],
   "execution_count": 4,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wwMzCqpD6vkL"
   },
   "source": [
    "Set the experiment configurations, passing the `Configs` dataclass to the experiment as a dictionary"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 17
    },
    "id": "e6hmQhTw4nks",
    "outputId": "77eca625-7205-49ea-f275-23f2710c4d84"
   },
   "source": [
    "experiment.configs(dataclasses.asdict(conf))"
   ],
   "execution_count": 5,
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/html": [
       "<pre style=\"overflow-x: scroll;\"></pre>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     }
    }
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DHyNvXfnzeWQ"
   },
   "source": [
    "Create [`Trainer`](https://nn.labml.ai/transformers/glu_variants/simple.html)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "59ZeTv5SzcVe"
   },
   "source": [
    "trainer = Trainer(conf)"
   ],
   "execution_count": 6,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EvI7MtgJ61w5"
   },
   "source": [
    "Set PyTorch models for loading and saving"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "GDlt7dp-5ALt"
   },
   "source": [
    "experiment.add_pytorch_models({'model': trainer.model})"
   ],
   "execution_count": 7,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KJZRf8527GxL"
   },
   "source": [
    "Start the experiment and run the training loop."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 255
    },
    "id": "aIAWo7Fw5DR8",
    "outputId": "18b8b334-f9e7-458b-f900-5828b4f9a5c8"
   },
   "source": [
    "with experiment.start():\n",
    "    trainer.train()"
   ],
   "execution_count": null,
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/html": [
       "<pre style=\"overflow-x: scroll;\">\n",
       "<strong><span style=\"text-decoration: underline\">glu_variants</span></strong>: <span style=\"color: #208FFB\">86b773f65fc911ebb2ac0242ac1c0002</span>\n",
       "\t[dirty]: <strong><span style=\"color: #DDB62B\">\"\"</span></strong>\n",
       "<span style=\"color: #C5C1B4\"></span>\n",
       "<span style=\"color: #C5C1B4\">--------------------------------------------------</span><span style=\"color: #DDB62B\"><strong><span style=\"text-decoration: underline\"></span></strong></span>\n",
       "<span style=\"color: #DDB62B\"><strong><span style=\"text-decoration: underline\">LABML WARNING</span></strong></span>\n",
       "<span style=\"color: #DDB62B\"><strong><span style=\"text-decoration: underline\"></span></strong></span>LabML App Warning: <span style=\"color: #60C6C8\">empty_token: </span><strong>Please create a valid token at https://app.labml.ai.</strong>\n",
       "<strong>Click on the experiment link to monitor the experiment and add it to your experiments list.</strong><span style=\"color: #C5C1B4\"></span>\n",
       "<span style=\"color: #C5C1B4\">--------------------------------------------------</span>\n",
       "<span style=\"color: #208FFB\">Monitor experiment at </span><a href='https://app.labml.ai/run?uuid=86b773f65fc911ebb2ac0242ac1c0002' target='blank'>https://app.labml.ai/run?uuid=86b773f65fc911ebb2ac0242ac1c0002</a>\n",
       "<span style=\"color: #C5C1B4\">It is</span><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong>\n",
       "<span style=\"color: #C5C1B4\">It is</span><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong>\n",
       "<span style=\"color: #C5C1B4\">It is</span><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong>\n",
       "<span style=\"color: #C5C1B4\">It is</span><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong><strong>t</strong><strong>h</strong><strong>e</strong><strong> </strong>\n",
       "<strong><span style=\"color: #DDB62B\">1,925,120:  </span></strong>Train:<span style=\"color: #C5C1B4\">   1%</span><span style=\"color: #208FFB\">  8,427,381ms  </span> loss.train: <strong> 2.42505</strong>  <span style=\"color: #208FFB\">8,427,381ms</span><span style=\"color: #D160C4\">  0:01m/ 11:40m  </span></pre>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     }
    }
   ]
  }
 ]
}